# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
from typing import Dict

import numpy as np

from ... import opcodes as OperandDef
from ...serialize import ValueType, KeyField, StringField, BytesField, TupleField
from ...lib.filesystem import get_fs, FSMap
from ...tiles import TilesError
from ...utils import check_chunks_unknown_shape
from .core import TensorDataStore


class ZarrOptions:
    """
    Thin wrapper around the dict of keyword options handed to zarr.

    The wrapper exposes a tokenize hook (``__mars_tokenize__``) and explicit
    pickle hooks so that the wrapped dict is the only state carried around.
    """

    def __init__(self, options: Dict):
        self._options = options

    def todict(self):
        """Return the wrapped options dict (the same object, not a copy)."""
        return self._options

    @staticmethod
    def _stringfy(v):
        """Return ``v`` unchanged if it is a ``str``, else its pickled bytes."""
        if isinstance(v, str):
            return v
        return pickle.dumps(v)

    def __mars_tokenize__(self):
        """Return ``(keys, tokens)``: option names and their stringified values."""
        names = list(self._options)
        tokens = [self._stringfy(value) for value in self._options.values()]
        return names, tokens

    def __getstate__(self):
        # the options dict is the whole pickled state
        return self._options

    def __setstate__(self, state):
        # restore the options dict produced by ``__getstate__``
        self._options = state


class TensorToZarrDataStore(TensorDataStore):
    """
    Operand that writes a tensor into a zarr array.

    ``tile`` creates the target zarr array and produces one store chunk per
    input chunk; ``execute`` writes a single chunk's data into the region of
    the array given by the chunk's ``axis_offsets``.
    """
    _op_type_ = OperandDef.TENSOR_STORE_ZARR

    _input = KeyField('input')
    _path = StringField('path')
    _group = StringField('group')
    _dataset = StringField('dataset')
    # options object is pickled to bytes when serialized and unpickled on load
    _zarr_options = BytesField('zarr_options', on_serialize=pickle.dumps,
                               on_deserialize=pickle.loads)
    _axis_offsets = TupleField('axis_offsets', ValueType.int32)

    def __init__(self, path=None, group=None, dataset=None, zarr_options=None,
                 axis_offsets=None, **kw):
        super().__init__(_path=path, _group=group, _dataset=dataset,
                         _zarr_options=zarr_options, _axis_offsets=axis_offsets, **kw)

    @property
    def path(self):
        """Path of the zarr store."""
        return self._path

    @property
    def group(self):
        """Group inside the store, or None."""
        return self._group

    @property
    def dataset(self):
        """Name of the array inside the store/group."""
        return self._dataset

    @property
    def zarr_options(self):
        """Options object whose ``todict()`` is forwarded to ``zarr.open``."""
        return self._zarr_options

    @property
    def axis_offsets(self):
        """Per-axis start offsets of this chunk inside the target array."""
        return self._axis_offsets

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        self._input = self._inputs[0]

    @classmethod
    def tile(cls, op):
        """
        Create the zarr array and build one store chunk per input chunk.

        Every output chunk has an all-zero shape and carries the offsets at
        which its input chunk starts along each axis.
        """
        import zarr

        check_chunks_unknown_shape(op.inputs, TilesError)
        source = op.input

        # create dataset
        filesystem = get_fs(op.path, None)
        root = op.path if op.group is None else op.path + '/' + op.group
        store = FSMap(root, filesystem)
        # zarr chunk size along each axis is the largest split on that axis
        largest_chunk = tuple(max(splits) for splits in source.nsplits)
        zarr.open(store, 'w', path=op.dataset,
                  dtype=source.dtype, shape=source.shape,
                  chunks=largest_chunk,
                  **op.zarr_options.todict())

        # boundaries[axis][i] is the start offset of the i-th split on that axis
        boundaries = [[0] + np.cumsum(splits).tolist()
                      for splits in source.nsplits]
        stored_chunks = []
        for piece in source.chunks:
            piece_op = op.copy().reset_key()
            piece_op._axis_offsets = tuple(
                bounds[pos] for pos, bounds in zip(piece.index, boundaries))
            stored = piece_op.new_chunk([piece], shape=(0,) * piece.ndim,
                                        index=piece.index)
            stored_chunks.append(stored)

        tiled_op = op.copy()
        result = op.outputs[0]
        empty_splits = tuple((0,) * len(splits) for splits in source.nsplits)
        return tiled_op.new_tensors(op.inputs, shape=result.shape,
                                    order=result.order,
                                    nsplits=empty_splits, chunks=stored_chunks)

    @classmethod
    def execute(cls, ctx, op):
        """
        Write the input chunk's data into its region of the zarr array.

        The result stored in ``ctx`` is an empty array with the same number of
        dimensions and dtype as the written data.
        """
        import zarr

        store = FSMap(op.path, get_fs(op.path, None))
        target = zarr.Group(store=store, path=op.group)[op.dataset]

        data = ctx[op.inputs[0].key]
        # region covered by this chunk: [offset, offset + size) on every axis
        region = tuple(slice(start, start + extent)
                       for start, extent in zip(op.axis_offsets, data.shape))
        target[region] = data

        ctx[op.outputs[0].key] = np.empty((0,) * data.ndim, dtype=data.dtype)


def tozarr(path, x, group=None, dataset=None, **zarr_options):
    """
    Build a :class:`TensorToZarrDataStore` operand for ``x`` and apply it.

    Parameters
    ----------
    path : str or zarr.Array
        Either the location to write to, or an existing zarr array from which
        the store path, dataset name, and (when set) ``compressor`` and
        ``filters`` are taken. When a ``str`` is given and ``dataset`` is
        None, the part after the last ``/`` is used as the dataset name and
        the rest as the store path.
    x
        The tensor to store; it is passed to the operand call.
    group : str, optional
        Group inside the store. For a zarr.Array on a non-``FSMap`` store it
        is derived from the array's path if not given.
    dataset : str, optional
        Name of the array inside the store/group.
    **zarr_options
        Extra options wrapped in :class:`ZarrOptions`; the operand forwards
        them to ``zarr.open`` when the array is created.

    Returns
    -------
    The result of calling the operand on ``x``.

    Raises
    ------
    TypeError
        If ``path`` is neither a ``str`` nor a ``zarr.Array``.
    ValueError
        If ``path`` is a ``str`` without ``/`` and ``dataset`` is not given,
        so no dataset name can be inferred.
    """
    import zarr

    if isinstance(path, zarr.Array):
        arr = path
        if isinstance(arr.store, FSMap):
            root = arr.store.root
            path, dataset = root.rsplit('/', 1)
        else:
            path = arr.store.path
            if '/' in arr.path and group is None:
                group = arr.path.rsplit('/', 1)[0]
            dataset = arr.basename
            if not dataset:
                # array sits at the store root: the last path segment names it
                path, dataset = path.rsplit('/', 1)
        # settings of the existing array take precedence over passed options
        for attr in ['compressor', 'filters']:
            value = getattr(arr, attr)
            if value:
                zarr_options[attr] = value
    elif isinstance(path, str):
        if dataset is None:
            if '/' not in path:
                # previously surfaced as an opaque tuple-unpacking ValueError
                raise ValueError(
                    f'cannot infer dataset name from path {path!r}, '
                    'pass `dataset` explicitly or use a path '
                    'like "<store>/<dataset>"')
            path, dataset = path.rsplit('/', 1)
    else:
        raise TypeError('`path` passed has wrong type, '
                        'expect str, or zarr.Array, '
                        f'got {type(path)}')

    op = TensorToZarrDataStore(path=path, group=group, dataset=dataset,
                               zarr_options=ZarrOptions(zarr_options))
    return op(x)
